from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker

def initialize_rng_states(seed=1234):
    """Ensure the global CUDA RNG tracker has a ``'model-parallel-rng'`` state.

    Fetches the tracker via ``get_cuda_rng_tracker()``. If no state named
    ``'model-parallel-rng'`` is registered yet, it is added with *seed*.
    Otherwise the existing state is left untouched, so calling this function
    repeatedly is safe.

    Args:
        seed: Seed used only when the state has to be created.

    Returns:
        The tracker object returned by ``get_cuda_rng_tracker()``.
    """
    tracker = get_cuda_rng_tracker()

    # DeepSpeed's CudaRNGStatesTracker stores its states in a private
    # ``states_`` dict and exposes a copy through ``get_states()``. It has no
    # ``states`` attribute, so probing ``tracker.states`` always reported the
    # state as missing. A second call then raised inside ``tracker.add``.
    # Prefer the public accessor, and fall back to a ``states`` attribute for
    # trackers that expose one directly.
    if hasattr(tracker, 'get_states'):
        existing_states = tracker.get_states()
    else:
        existing_states = getattr(tracker, 'states', {})

    if 'model-parallel-rng' not in existing_states:
        print("Adding missing model-parallel-rng state with seed:", seed)
        tracker.add('model-parallel-rng', seed=seed)
    else:
        print("model-parallel-rng state already exists")

    return tracker